/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.util;

import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.util.Migration;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.RunnableFuture;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.doAnswer;

/**
 * Extension of {@link OneInputStreamOperatorTestHarness} that allows the operator to get
 * a {@link KeyedStateBackend}.
 *
 * <p>The harness stubs {@code mockTask.createKeyedStateBackend(...)} so that every request
 * creates a fresh backend from the harness' state backend and restores into it whatever keyed
 * state was previously handed to {@link #restore(StreamStateHandle)} or
 * {@link #initializeState(OperatorStateHandles)}.
 *
 * @param <K> type of the key extracted by the {@link KeySelector}
 * @param <IN> type of the operator's input elements
 * @param <OUT> type of the operator's output elements
 */
public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
		extends OneInputStreamOperatorTestHarness<IN, OUT> {

	// in case the operator creates one we store it here so that we
	// can snapshot its state
	private AbstractKeyedStateBackend<?> keyedStateBackend = null;

	// when we restore we keep the state here so that we can call restore
	// when the operator requests the keyed state backend
	private List<KeyedStateHandle> restoredKeyedState = null;

	/**
	 * Creates a harness with explicit parallelism settings.
	 *
	 * @param operator the operator under test
	 * @param keySelector selector that is registered as state partitioner for input 0
	 * @param keyType type information used to create the state key serializer
	 * @param maxParallelism forwarded to the parent harness
	 * @param numSubtasks forwarded to the parent harness
	 * @param subtaskIndex forwarded to the parent harness
	 * @throws Exception if the parent harness cannot be set up
	 */
	public KeyedOneInputStreamOperatorTestHarness(
			OneInputStreamOperator<IN, OUT> operator,
			final KeySelector<IN, K> keySelector,
			TypeInformation<K> keyType,
			int maxParallelism,
			int numSubtasks,
			int subtaskIndex) throws Exception {
		super(operator, maxParallelism, numSubtasks, subtaskIndex);

		ClosureCleaner.clean(keySelector, false);
		config.setStatePartitioner(0, keySelector);
		config.setStateKeySerializer(keyType.createSerializer(executionConfig));

		setupMockTaskCreateKeyedBackend();
	}

	/**
	 * Creates a harness with max parallelism 1, a single subtask and subtask index 0.
	 *
	 * @param operator the operator under test
	 * @param keySelector selector that is registered as state partitioner for input 0
	 * @param keyType type information used to create the state key serializer
	 * @throws Exception if the parent harness cannot be set up
	 */
	public KeyedOneInputStreamOperatorTestHarness(
			OneInputStreamOperator<IN, OUT> operator,
			final KeySelector<IN, K> keySelector,
			TypeInformation<K> keyType) throws Exception {
		this(operator, keySelector, keyType, 1, 1, 0);
	}

	/**
	 * Creates a harness that runs the operator in the given {@link Environment}.
	 *
	 * @param operator the operator under test
	 * @param keySelector selector that is registered as state partitioner for input 0
	 * @param keyType type information used to create the state key serializer
	 * @param environment environment forwarded to the parent harness
	 * @throws Exception if the parent harness cannot be set up
	 */
	public KeyedOneInputStreamOperatorTestHarness(
			final OneInputStreamOperator<IN, OUT> operator,
			final  KeySelector<IN, K> keySelector,
			final TypeInformation<K> keyType,
			final Environment environment) throws Exception {

		super(operator, environment);

		ClosureCleaner.clean(keySelector, false);
		config.setStatePartitioner(0, keySelector);
		config.setStateKeySerializer(keyType.createSerializer(executionConfig));

		setupMockTaskCreateKeyedBackend();
	}

	/**
	 * Stubs {@code mockTask.createKeyedStateBackend(...)}: each call disposes the previously
	 * created backend (if any), creates a new one from {@code stateBackend}, restores
	 * {@link #restoredKeyedState} into it and remembers it so that it can be snapshotted later.
	 */
	private void setupMockTaskCreateKeyedBackend() {

		try {
			doAnswer(new Answer<KeyedStateBackend>() {
				@Override
				public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {

					final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0];
					final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1];
					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];

					// only one backend is tracked at a time; release the old one first
					if (keyedStateBackend != null) {
						keyedStateBackend.dispose();
					}

					keyedStateBackend = stateBackend.createKeyedStateBackend(
							mockTask.getEnvironment(),
							new JobID(),
							"test_op",
							keySerializer,
							numberOfKeyGroups,
							keyGroupRange,
							mockTask.getEnvironment().getTaskKvStateRegistry());

					keyedStateBackend.restore(restoredKeyedState);

					return keyedStateBackend;
				}
			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
		} catch (Exception e) {
			throw new RuntimeException(e.getMessage(), e);
		}
	}

	/**
	 * Takes a legacy-style snapshot into an in-memory stream.
	 *
	 * <p>The stream layout is: the operator's own state (only if the operator is a
	 * {@link StreamCheckpointedOperator}), followed by a marker byte that is {@code 1} when a
	 * keyed state backend exists and {@code 0} otherwise. If the marker is {@code 1}, it is
	 * followed by the Java-serialized {@link KeyedStateHandle} of the keyed backend snapshot.
	 * {@link #restore(StreamStateHandle)} reads exactly this layout.
	 *
	 * @param checkpointId id passed to the operator and the keyed backend snapshot
	 * @param timestamp timestamp passed to the operator and the keyed backend snapshot
	 * @return handle to the written in-memory state
	 * @throws Exception if the operator snapshot, the keyed backend snapshot or the
	 *         serialization fails; the output stream is closed before the exception propagates
	 */
	@Override
	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
		// simply use an in-memory handle
		MemoryStateBackend backend = new MemoryStateBackend();

		CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op");
		CheckpointStreamFactory.CheckpointStateOutputStream outStream =
				streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);

		try {
			if (operator instanceof StreamCheckpointedOperator) {
				((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp);
			}

			if (keyedStateBackend != null) {
				RunnableFuture<KeyedStateHandle> keyedSnapshotRunnable = keyedStateBackend.snapshot(
						checkpointId,
						timestamp,
						streamFactory,
						CheckpointOptions.forFullCheckpoint());
				// the backend may hand back a future that still has to be executed;
				// get() below blocks until the result is available
				if (!keyedSnapshotRunnable.isDone()) {
					Thread runner = new Thread(keyedSnapshotRunnable);
					runner.start();
				}
				outStream.write(1);
				ObjectOutputStream oos = new ObjectOutputStream(outStream);
				oos.writeObject(keyedSnapshotRunnable.get());
				oos.flush();
			} else {
				outStream.write(0);
			}
			return outStream.closeAndGetHandle();
		} catch (Exception e) {
			// do not leak the stream when the snapshot fails half-way; keep the
			// original failure as the primary exception
			try {
				outStream.close();
			} catch (Exception closeFailure) {
				e.addSuppressed(closeFailure);
			}
			throw e;
		}
	}

	/**
	 * Restores from a handle produced by {@link #snapshotLegacy(long, long)}.
	 *
	 * <p>The operator's own state is restored first (only if the operator is a
	 * {@link StreamCheckpointedOperator}). Then the marker byte is read; if it is {@code 1},
	 * the serialized {@link KeyedStateHandle} is read and kept so that it is restored into
	 * the keyed state backend once the operator requests one.
	 *
	 * @param snapshot the state handle to read from
	 * @throws Exception if reading the stream or restoring the operator state fails
	 */
	@Override
	public void restore(StreamStateHandle snapshot) throws Exception {
		try (FSDataInputStream inStream = snapshot.openInputStream()) {

			if (operator instanceof StreamCheckpointedOperator) {
				((StreamCheckpointedOperator) operator).restoreState(inStream);
			}

			byte keyedStatePresent = (byte) inStream.read();
			if (keyedStatePresent == 1) {
				ObjectInputStream ois = new ObjectInputStream(inStream);
				this.restoredKeyedState = Collections.singletonList((KeyedStateHandle) ois.readObject());
			}
		}
	}

	/**
	 * Returns {@code true} if at least one of the given handles implements {@link Migration}.
	 */
	private static boolean hasMigrationHandles(Collection<KeyedStateHandle> allKeyGroupsHandles) {
		for (KeyedStateHandle handle : allKeyGroupsHandles) {
			if (handle instanceof Migration) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Returns the number of state entries reported by the keyed state backend.
	 *
	 * @throws UnsupportedOperationException if the current keyed state backend is not a
	 *         {@link HeapKeyedStateBackend} (this includes the case that none was created yet)
	 */
	public int numKeyedStateEntries() {
		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries();
		} else {
			throw new UnsupportedOperationException();
		}
	}

	/**
	 * Returns the number of state entries reported by the keyed state backend for the
	 * given namespace.
	 *
	 * @param namespace the namespace passed on to the backend
	 * @throws UnsupportedOperationException if the current keyed state backend is not a
	 *         {@link HeapKeyedStateBackend} (this includes the case that none was created yet)
	 */
	public <N> int numKeyedStateEntries(N namespace) {
		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(namespace);
		} else {
			throw new UnsupportedOperationException();
		}
	}

	/**
	 * Selects the keyed state that is restored into the keyed backend and then delegates to
	 * the parent harness.
	 *
	 * <p>If {@code operatorStateHandles} is {@code null}, previously set restore state is left
	 * untouched. Otherwise it is replaced: handles are filtered down to the key-group range of
	 * this subtask, unless any handle is a {@link Migration} handle, in which case all managed
	 * keyed state handles are kept as they are.
	 *
	 * @param operatorStateHandles the state to initialize from, may be {@code null}
	 * @throws Exception if the parent harness fails to initialize the state
	 */
	@Override
	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
		if (operatorStateHandles != null) {
			int numKeyGroups = getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks();
			int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
			int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();

			// create a new OperatorStateHandles that only contains the state for our key-groups

			List<KeyGroupRange> keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions(
					numKeyGroups,
					numSubtasks);

			KeyGroupRange localKeyGroupRange =
					keyGroupPartitions.get(subtaskIndex);

			restoredKeyedState = null;
			Collection<KeyedStateHandle> managedKeyedState = operatorStateHandles.getManagedKeyedState();
			if (managedKeyedState != null) {

				// if we have migration handles, don't reshuffle state and preserve
				// the migration tag
				if (hasMigrationHandles(managedKeyedState)) {
					List<KeyedStateHandle> result = new ArrayList<>(managedKeyedState.size());
					result.addAll(managedKeyedState);
					restoredKeyedState = result;
				} else {
					restoredKeyedState = StateAssignmentOperation.getKeyedStateHandles(
							managedKeyedState,
							localKeyGroupRange);
				}
			}
		}

		super.initializeState(operatorStateHandles);
	}
}
